import torch
import torchvision
import torchvision.transforms as transforms
from torch import nn
import argparse
import torch.optim
from tqdm import tqdm
import os
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt

def train(net, train_loader, optimizer, criterion, writer, args, epoch, index_num):
    """Run one training epoch and log the loss of every batch.

    Args:
        net: Model to optimise; switched to training mode here.
        train_loader: Iterable yielding ``(inputs, labels)`` batches.
        optimizer: Optimizer whose gradients are zeroed and stepped per batch.
        criterion: Callable ``criterion(outputs, labels)`` returning the loss.
        writer: Object exposing ``add_scalar(tag, value, step)``
            (presumably a TensorBoard ``SummaryWriter``).
        args: Namespace providing ``device``, the target for inputs and labels.
        epoch: Epoch number, used only for the progress-bar label.
        index_num: Global step assigned to the first batch of this epoch.

    Returns:
        The global step following the last batch. ``index_num`` is an int, so
        the increments made here are invisible to the caller; feed the return
        value into the next call to keep the "loss/train" curve continuous.
    """
    net.train()
    train_tqdm = tqdm(train_loader, desc="Epoch " + str(epoch))
    for inputs, labels in train_tqdm:
        optimizer.zero_grad()
        outputs = net(inputs.to(args.device))
        loss = criterion(outputs, labels.to(args.device))
        loss.backward()
        optimizer.step()
        # Convert to a Python float once and reuse it for both consumers.
        loss_value = loss.item()
        writer.add_scalar("loss/train", loss_value, index_num)
        index_num += 1
        train_tqdm.set_postfix({"loss": "%.3g" % loss_value})
    return index_num

def test(net, test_loader, criterion, writer, args, epoch, loss_vector, accuracy_vector):
    """Evaluate ``net`` on ``test_loader`` and record loss and accuracy.

    Args:
        net: Model to evaluate; switched to eval mode here.
        test_loader: Iterable of ``(data, target)`` batches that supports
            ``len()`` and exposes a sized ``dataset`` attribute.
        criterion: Callable ``criterion(output, target)`` returning the loss.
        writer: Object exposing ``add_scalar(tag, value, step)``
            (presumably a TensorBoard ``SummaryWriter``).
        args: Namespace providing ``device``, the target for data and labels.
        epoch: Epoch number, used as the logging step and in the printout.
        loss_vector: List that receives this epoch's mean validation loss.
        accuracy_vector: List that receives this epoch's accuracy (percent,
            as a 0-dim float32 tensor).

    Returns:
        Tuple ``(correct, val_loss)``: the number of correct predictions as a
        0-dim int64 CPU tensor, and the loss averaged over
        ``len(test_loader)``.
    """
    net.eval()
    val_loss = 0
    # Tensor accumulator: the counter has the same type on every iteration.
    correct = torch.zeros((), dtype=torch.long)
    # Evaluation needs no gradients; skipping the autograd graph saves memory
    # and time on every forward pass.
    with torch.no_grad():
        for data, target in test_loader:
            data = data.to(args.device)
            target = target.to(args.device)
            output = net(data)
            val_loss += criterion(output, target).item()
            # Index of the largest score along dim 1 is the predicted class.
            pred = output.max(1)[1]
            correct += pred.eq(target).cpu().sum()
    val_loss /= len(test_loader)
    loss_vector.append(val_loss)
    writer.add_scalar("loss/validation", val_loss, epoch)
    accuracy = 100. * correct.to(torch.float32) / len(test_loader.dataset)
    accuracy_vector.append(accuracy)
    writer.add_scalar("accuracy/validation", accuracy, epoch)

    print("***** Eval results *****")
    # correct.item() prints the plain count instead of "tensor(N)".
    print('epoch: {}, Validation set: Average loss: {:.4f}, Accuracy: {}/{} ({:.4f}%)\n'.format(
        epoch, val_loss, correct.item(), len(test_loader.dataset), accuracy))

    return correct, val_loss